import onnx
import argparse

# ONNX treats both '' and 'ai.onnx' as the default operator-set domain; opset
# entries for it must never be dropped.
_DEFAULT_DOMAINS = frozenset({'', 'ai.onnx'})


def strip_flash_attention_domain(model, op_type='FlashAttention'):
    """Move ``op_type`` nodes into the default domain and drop the stale opset.

    Every top-level graph node whose ``op_type`` matches has its ``domain``
    reset to ``''``. Afterwards, each ``opset_import`` entry is deleted if:
    - the rewritten nodes used its domain, and
    - no remaining graph node or local function still references that domain.
    Default-domain entries are never removed.

    The entry is located by its domain, wherever it sits in the list. Nothing
    is removed when no matching node exists or the domain is still in use.

    NOTE(review): nodes inside control-flow subgraphs (If/Loop/Scan bodies)
    are neither rewritten nor scanned for domain usage.

    Args:
        model: the loaded ONNX model (e.g. from ``onnx.load_model``);
            modified in place.
        op_type: operator type whose domain should be cleared.

    Returns:
        Sorted list of the domain names whose opset entries were deleted.
    """
    rewritten_domains = set()
    for node in model.graph.node:
        if node.op_type == op_type:
            rewritten_domains.add(node.domain)
            node.domain = ''

    # A domain may only be dropped if nothing left in the model refers to it.
    still_used = {node.domain for node in model.graph.node}
    for function in getattr(model, 'functions', ()):
        # Conservatively keep domains used by, or declared for, local functions.
        still_used.add(function.domain)
        still_used.update(node.domain for node in function.node)
    removable = rewritten_domains - still_used - _DEFAULT_DOMAINS

    removed = set()
    # Walk backwards so a deletion never shifts entries not yet inspected.
    for index in reversed(range(len(model.opset_import))):
        domain = model.opset_import[index].domain
        if domain in removable:
            del model.opset_import[index]
            removed.add(domain)
    return sorted(removed)


def main(argv=None):
    """Load a model, strip the FlashAttention domain, and save the result.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        prog='ascend_domain_utils',
        description='delete flash attention domain',
        epilog='')
    # Both paths are mandatory; without them onnx.load_model / save_model
    # would be handed None and fail with an unclear error.
    parser.add_argument('--model', required=True,
                        help='model file location')
    parser.add_argument('--output', required=True,
                        help='model output file location')
    args = parser.parse_args(argv)
    model = onnx.load_model(args.model)
    strip_flash_attention_domain(model)
    onnx.save_model(model, args.output)


if __name__ == '__main__':
    main()
